Source code for hysop.backend.device.opencl.opencl_types

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import string, re
import sympy as sm
import numpy as np
import itertools as it

from hysop import __KERNEL_DEBUG__, vprint, dprint
from hysop.backend.device.opencl import cl, clArray, clTypes
from hysop.tools.numerics import MPZ, MPQ, MPFR, F2Q
from hysop.tools.htypes import first_not_None, to_tuple

vsizes = [1, 2, 3, 4, 8, 16]
base_types = ["float", "signed", "unsigned"]
float_base_types = ["half", "float", "double"]
signed_base_types = ["char", "short", "int", "long"]
unsigned_base_types = ["uchar", "ushort", "uint", "ulong"]


def _expand_vector_types(scalar_names):
    """Return every scalar and vector type name built from *scalar_names*.

    Names are ordered by scalar type first, then by vector width, following
    ``vsizes``. A width of 1 yields the bare scalar name. The scalar ``half``
    is skipped, so only its vector variants are listed.
    """
    return [
        name if width == 1 else name + str(width)
        for name, width in it.product(scalar_names, vsizes)
        if not (name == "half" and width == 1)
    ]


# The lists are built explicitly instead of resolving "<kind>_base_types" and
# "<kind>_types" through eval(), which hid the data flow and left loop
# temporaries behind in the module namespace.
float_types = _expand_vector_types(float_base_types)
signed_types = _expand_vector_types(signed_base_types)
unsigned_types = _expand_vector_types(unsigned_base_types)
integer_types = signed_types + unsigned_types
builtin_types = integer_types + float_types


# OpenCL extension name associated with each floating point base type
# (None when no extension is listed for the type).
float_base_type_require = {
    "half": "cl_khr_fp16",
    "float": None,
    "double": "cl_khr_fp64",
}

# Decimal digit count per floating point base type; float_to_dec_str uses it
# to decide how many fractional digits are kept.
FLT_DIG = {
    "half": 3,  # = HALF_DIG
    "float": 6,  # =  FLT_DIG
    "double": 15,  # =  DBL_DIG
}
# Mantissa bit count per floating point base type; float_to_hex_str uses it
# to decide how many hexadecimal digits are kept.
FLT_MANT_DIG = {
    "half": 11,  # = HALF_MANT_DIG
    "float": 24,  # =  FLT_MANT_DIG
    "double": 53,  # =  DBL_MANT_DIG
}
# Suffix appended to dumped floating point literals of each base type.
FLT_LITERAL = {"half": "h", "float": "f", "double": ""}
# Byte count per floating point base type.
FLT_BYTES = {"half": 2, "float": 4, "double": 8}


def basetype(fulltype):
    """Return *fulltype* with every decimal digit removed (``float4`` -> ``float``)."""
    return "".join(ch for ch in fulltype if ch not in string.digits)
def components(fulltype):
    """Return the component count encoded in a type name.

    ASCII letters and underscores are dropped; what remains is parsed as an
    integer. A name without any remaining character counts as 1 component.
    """
    ignored = string.ascii_letters + "_"
    remainder = "".join(ch for ch in fulltype if ch not in ignored)
    if remainder == "":
        return 1
    return int(remainder)
def mangle_vtype(fulltype):
    """Return a short tag for *fulltype*: first letter of its base type + component count."""
    initial = basetype(fulltype)[0]
    count = components(fulltype)
    return f"{initial}{count}"
def vtype(basetype, N):
    """Return the type name for base type *basetype* with *N* components.

    A single component yields the bare base type name, otherwise *N* is
    appended to it.
    """
    suffix = str(N) if N != 1 else ""
    return basetype + suffix
def itype(fulltype):
    """Return the ``int`` type name with as many components as *fulltype*."""
    count = components(fulltype)
    if count == 1:
        return "int"
    return "int" + str(count)
def uitype(fulltype):
    """Return the ``uint`` type name with as many components as *fulltype*."""
    count = components(fulltype)
    if count == 1:
        return "uint"
    return "uint" + str(count)
def np_dtype(fulltype):
    """Return the result of ``cl.tools.get_or_register_dtype`` for the type name *fulltype*."""
    lookup = cl.tools.get_or_register_dtype
    return lookup(fulltype)
def vtype_component_adressing(i, mode="hex"):
    """Return the character addressing component *i* of a vector type.

    *mode* selects the alphabet: ``"hex"`` (lowercase hexadecimal digits),
    ``"HEX"`` (uppercase hexadecimal digits) or ``"pos"`` (``xyzw``).

    Raises ValueError for any other mode.
    """
    alphabets = (
        ("hex", "0123456789abcdef"),
        ("HEX", "0123456789ABCDEF"),
        ("pos", "xyzw"),
    )
    for name, alphabet in alphabets:
        if mode == name:
            return alphabet[i]
    raise ValueError("Bad vtype component adressing mode!")
def vtype_access(i, N, mode="hex"):
    """Return the accessor suffix for component *i* of an *N*-component type.

    Single-component types need no accessor and yield an empty string.
    Otherwise the component character is returned, prefixed by ``s`` when
    *mode* is a hexadecimal mode (compared case-insensitively).
    """
    assert i < N
    if N == 1:
        return ""
    prefix = "s" if mode.lower() == "hex" else ""
    return prefix + vtype_component_adressing(i, mode)
def float_to_hex_str(f, fbtype):
    """Format *f* as a signed hexadecimal floating point literal.

    Parameters
    ----------
    f:
        Value to dump; it is converted with ``float()``.
    fbtype:
        Floating point base type name, used as key into ``FLT_MANT_DIG``
        and ``FLT_LITERAL``.

    Returns
    -------
    str
        ``"NAN"`` when *f* is not a number, ``"+INFINITY"`` or
        ``"-INFINITY"`` when it is infinite, otherwise a string such as
        ``"+0x1.800000p+0f"``. The sign is always explicit and the mantissa
        is truncated (not rounded) to the digit count derived from
        ``FLT_MANT_DIG[fbtype]``.

    Raises
    ------
    KeyError
        If *fbtype* is unknown and *f* is finite.
    """
    if f != f:
        return "NAN"

    value = float(f)
    # float.hex() returns 'inf' / '-inf' for infinities, which has neither a
    # '0x' prefix nor a 'p' exponent to split on, so handle them up front.
    if value == float("inf"):
        return "+INFINITY"
    if value == float("-inf"):
        return "-INFINITY"

    sign, _, body = value.hex().partition("0x")
    mantissa, exponent = body.split("p")

    mant_dig = FLT_MANT_DIG[fbtype]
    literal = FLT_LITERAL[fbtype]
    # +2 = leading one or zero and decimal point characters
    # (1.abcde... or 0.abcde...)
    nhex = (mant_dig - 1 + 3) // 4 + 2

    prefix = ("+" if sign == "" else sign) + "0x"
    return prefix + mantissa[:nhex] + "p" + exponent + literal
def float_to_dec_str(f, fbtype):
    """Format *f* as a decimal floating point literal for base type *fbtype*.

    The ``repr`` of ``float(f)`` is split into integral part, fractional
    digits and optional exponent. The fractional digits are sliced according
    to ``FLT_DIG[fbtype]`` and the number of significant integral digits,
    then the ``FLT_LITERAL[fbtype]`` suffix is appended.

    ``"NAN"`` is returned when *f* is not a number. When the ``repr``
    contains no decimal point it is returned unchanged, without suffix.
    """
    if f != f:
        return "NAN"

    pieces = repr(float(f)).split(".")
    if len(pieces) == 1:
        return pieces[0]
    integral, fractional = pieces[0], pieces[1]

    digits, separator, exponent = fractional.partition("e")

    dig = FLT_DIG[fbtype]
    literal = FLT_LITERAL[fbtype]

    # Significant integral digits: sign characters and leading zeros ignored.
    unsigned_integral = integral.replace("+", "").replace("-", "")
    sig = len(unsigned_integral.lstrip("0"))

    head = "+" if (integral == "") else integral + "."
    kept = digits[: dig - sig + 1]
    if separator:
        tail = "e" + exponent + literal
    else:
        tail = literal
    return head + kept + tail
# pyopencl specific vec = clTypes
def npmake(dtype):
    """Return a one-argument callable that converts a scalar by calling *dtype* on it."""

    def make(scalar):
        return dtype(scalar)

    return make
# Lookup tables indexed like ``vsizes``: entry 0 is the numpy scalar type (or
# a scalar constructor) and the following entries are the ``vec`` attributes
# for the vector widths 2, 3, 4, 8 and 16.
_vwidths = (2, 3, 4, 8, 16)

vtype_int = [np.int32] + [getattr(vec, f"int{w}") for w in _vwidths]
vtype_uint = [np.uint32] + [getattr(vec, f"uint{w}") for w in _vwidths]
vtype_simple = [np.float32] + [getattr(vec, f"float{w}") for w in _vwidths]
vtype_double = [np.float64] + [getattr(vec, f"double{w}") for w in _vwidths]

cl_vec_types = vtype_int + vtype_uint + vtype_simple + vtype_double

make_int = [npmake(np.int32)] + [getattr(vec, f"make_int{w}") for w in _vwidths]
make_uint = [npmake(np.uint32)] + [getattr(vec, f"make_uint{w}") for w in _vwidths]
make_simple = [npmake(np.float32)] + [
    getattr(vec, f"make_float{w}") for w in _vwidths
]
make_double = [npmake(np.float64)] + [
    getattr(vec, f"make_double{w}") for w in _vwidths
]
def simplen(n):
    """Return the single precision type with *n* components.

    ``n == 1`` yields ``np.float32``; other widths are looked up in
    ``vtype_simple`` at the position of *n* in ``vsizes``.
    """
    if n != 1:
        return vtype_simple[vsizes.index(n)]
    return np.float32
def doublen(n):
    """Return the double precision type with *n* components.

    ``n == 1`` yields ``np.float64``; other widths are looked up in
    ``vtype_double`` at the position of *n* in ``vsizes``.
    """
    if n != 1:
        return vtype_double[vsizes.index(n)]
    return np.float64
def intn(n):
    """Return the signed 32-bit integer type with *n* components.

    ``n == 1`` yields ``np.int32``; other widths are looked up in
    ``vtype_int`` at the position of *n* in ``vsizes``.
    """
    if n != 1:
        return vtype_int[vsizes.index(n)]
    return np.int32
def uintn(n):
    """Return the unsigned 32-bit integer type with *n* components.

    ``n == 1`` yields ``np.uint32``; other widths are looked up in
    ``vtype_uint`` at the position of *n* in ``vsizes``.
    """
    if n != 1:
        return vtype_uint[vsizes.index(n)]
    return np.uint32
# Dispatch table used by typen(): base type name -> function returning the
# type for a given component count. "float" and "simple" share the same entry.
_typen = {
    "float": simplen,
    "simple": simplen,
    "double": doublen,
    "int": intn,
    "uint": uintn,
}
def typen(btype, n):
    """Return the type for base type name *btype* with *n* components.

    Raises KeyError when *btype* has no entry in ``_typen``.
    """
    type_for_width = _typen[btype]
    return type_for_width(n)
def make_simplen(vals, n, dval=0):
    """Build an *n*-component single precision value from *vals*.

    *vals* is converted with ``to_tuple`` and padded with *dval* up to *n*
    entries before being passed to the matching ``make_simple`` constructor.
    """
    padded = to_tuple(vals)
    missing = n - len(padded)
    padded += (dval,) * missing
    constructor = make_simple[vsizes.index(n)]
    return constructor(*padded)
def make_doublen(vals, n, dval=0):
    """Build an *n*-component double precision value from *vals*.

    *vals* is converted with ``to_tuple`` and padded with *dval* up to *n*
    entries before being passed to the matching ``make_double`` constructor.
    """
    padded = to_tuple(vals)
    missing = n - len(padded)
    padded += (dval,) * missing
    constructor = make_double[vsizes.index(n)]
    return constructor(*padded)
def make_intn(vals, n, dval=0):
    """Build an *n*-component signed integer value from *vals*.

    *vals* is converted with ``to_tuple`` and padded with *dval* up to *n*
    entries before being passed to the matching ``make_int`` constructor.
    """
    padded = to_tuple(vals)
    missing = n - len(padded)
    padded += (dval,) * missing
    constructor = make_int[vsizes.index(n)]
    return constructor(*padded)
def make_uintn(vals, n, dval=0):
    """Build an *n*-component unsigned integer value from *vals*.

    *vals* is converted with ``to_tuple`` and padded with *dval* up to *n*
    entries before being passed to the matching ``make_uint`` constructor.
    """
    padded = to_tuple(vals)
    missing = n - len(padded)
    padded += (dval,) * missing
    constructor = make_uint[vsizes.index(n)]
    return constructor(*padded)
# Dispatch table used by make_typen(): base type name -> make_*n builder.
# "float" and "simple" share the same entry.
_make_typen = {
    "float": make_simplen,
    "simple": make_simplen,
    "double": make_doublen,
    "int": make_intn,
    "uint": make_uintn,
}
def make_typen(btype):
    """Return the ``make_*n`` builder registered for base type name *btype*.

    Raises KeyError when *btype* has no entry in ``_make_typen``.
    """
    builder = _make_typen[btype]
    return builder
def cl_type_to_dtype(cl_type):
    """Return the type matching the type name *cl_type*.

    The name is split into base type and component count, which are then
    resolved through ``typen``.
    """
    return typen(basetype(cl_type), components(cl_type))
def cl_vec_type_to_scalar_and_count(cl_vec_type):
    """Split a type from ``cl_vec_types`` into ``(scalar type, component count)``.

    The scalar type is the first entry of the ``vtype_*`` table containing
    *cl_vec_type*; the count is the ``vsizes`` entry at the same position.

    Raises RuntimeError when no ``vtype_*`` table contains the type.
    """
    assert cl_vec_type in cl_vec_types
    families = (vtype_int, vtype_uint, vtype_simple, vtype_double)
    for family in families:
        if cl_vec_type not in family:
            continue
        position = family.index(cl_vec_type)
        return (family[0], vsizes[position])
    msg = "cl_vec_types != U(vtype_*)"
    raise RuntimeError(msg)
class TypeGen:
    """Dump Python, numpy, sympy and multiprecision scalars as source literals.

    Parameters
    ----------
    fbtype: str
        Floating point base type name used for float literals
        (key of the ``FLT_*`` tables).
    float_dump_mode: str
        ``"dec"``/``"decimal"`` to dump floats with ``float_to_dec_str`` or
        ``"hex"``/``"hexadecimal"`` to dump them with ``float_to_hex_str``.

    Raises
    ------
    ValueError
        If *float_dump_mode* is not one of the values above.
    """

    def __init__(self, fbtype="float", float_dump_mode="dec"):
        # Module level tables and helpers are exposed as attributes.
        self.float_base_types = float_base_types
        self.FLT_BYTES = FLT_BYTES
        self.FLT_DIG = FLT_DIG
        self.FLT_MANT_DIG = FLT_MANT_DIG
        self.FLT_LITERAL = FLT_LITERAL

        self.np_dtype = np_dtype
        self.float_to_dec_str = float_to_dec_str
        self.float_to_hex_str = float_to_hex_str

        self.fbtype = fbtype
        self.float_dump_mode = float_dump_mode

        if float_dump_mode in ["hex", "hexadecimal"]:
            self.float_to_str = float_to_hex_str
        elif float_dump_mode in ["dec", "decimal"]:
            self.float_to_str = float_to_dec_str
        else:
            raise ValueError(f"Unknown float_dump_mode '{float_dump_mode}'")

    @staticmethod
    def _unwrap_scalar(val):
        """Return the single element of a one-element container, else *val*.

        Lists and tuples of length 1 and numpy arrays of size 1 are
        unwrapped. Any other list, tuple, dict or numpy array raises
        ValueError; every other value is returned unchanged.
        """
        if isinstance(val, (list, tuple)) and len(val) == 1:
            return val[0]
        if isinstance(val, np.ndarray) and val.size == 1:
            return val.flatten()[0]
        if isinstance(val, (list, tuple, dict, np.ndarray)):
            raise ValueError(f"Value is not a scalar, got {val}.")
        return val

    def dump(self, val):
        """Return the source literal representing the scalar *val*.

        * booleans -> ``true`` / ``false``
        * floats -> parenthesized output of the configured float dumper
        * integers -> parenthesized signed value, suffixed with ``u`` for
          numpy unsigned integers and ``L`` for 64-bit numpy integers and
          MPZ; zero is dumped as a bare ``0`` plus suffix
        * rationals -> dumped as float, or as an explicit quotient when
          ``__KERNEL_DEBUG__`` is set
        * strings -> returned unchanged
        * anything else -> ``str(val)``

        Raises
        ------
        ValueError
            If *val* is a container holding more than one value.
        """
        val = self._unwrap_scalar(val)

        # bool is a subclass of int, so booleans have to be recognized
        # before the integer branch or they would be dumped as "(+True)".
        if isinstance(val, (bool, np.bool_)):
            return "true" if val else "false"

        if isinstance(val, (float, np.floating, MPFR, sm.Float)):
            sval = self.float_to_str(val, self.fbtype)
            return f"({sval})"

        if isinstance(val, (np.integer, int, MPZ, sm.Integer)):
            suffix = ""
            if isinstance(val, np.unsignedinteger):
                suffix += "u"
            if isinstance(val, (np.int64, np.uint64, MPZ)):
                suffix += "L"
            if val == 0:
                return "0" + suffix
            sign = "+" if val > 0 else "-"
            digits = str(val)
            if val < 0:
                # Drop the leading '-', the sign is emitted explicitly.
                digits = digits[1:]
            return f"({sign}{digits}{suffix})"

        if isinstance(val, (MPQ, sm.Rational)):
            if not __KERNEL_DEBUG__:
                return self.dump(float(val))
            if isinstance(val, MPQ):
                numerator, denominator = val.numerator, val.denominator
            else:
                numerator, denominator = val.p, val.q
            if denominator == 1:
                return str(numerator)
            return "({}.0{f}/{}.0{f})".format(
                numerator, denominator, f=FLT_LITERAL[self.fbtype]
            )

        if isinstance(val, str):
            return val

        # Unknown types are dumped through str() instead of being rejected.
        return str(val)

    def dumped_type(self, val):
        """Return the type name of the literal ``dump`` produces for *val*.

        Returns ``"bool"`` for booleans, ``self.fbtype`` for floats and
        rationals, ``"long"``/``"ulong"``/``"uint"``/``"int"`` for integers
        and ``None`` for any other value.

        Raises
        ------
        ValueError
            If *val* is a container holding more than one value.
        """
        val = self._unwrap_scalar(val)

        # bool is a subclass of int: test it first so that Python booleans
        # are not reported as "long".
        if isinstance(val, (bool, np.bool_)):
            return "bool"

        if isinstance(val, (float, np.floating, MPFR, sm.Float)):
            return self.fbtype

        if isinstance(val, (np.integer, int, MPZ, sm.Integer)):
            if isinstance(val, (np.int64, MPZ)):
                return "long"
            if isinstance(val, np.uint64):
                return "ulong"
            if isinstance(val, np.unsignedinteger):
                return "uint"
            if isinstance(val, int):
                return "long"
            return "int"

        if isinstance(val, (MPQ, sm.Rational)):
            return self.fbtype

        return None
# Struct type generation (type size and struct field offsets) is different for each device
# depending on architecture and compiler implementation and features.
# /!\ Do not use the same OpenCL typegen instance for two different devices that are
# not equivalent.
class OpenClTypeGen(TypeGen):
    """Type generator bound to an OpenCL device, context and platform.

    On top of ``TypeGen`` it exposes the module level type tables and
    helpers as attributes, selects the floating point helpers matching
    ``fbtype`` and computes the build options for that precision.

    Parameters
    ----------
    device, context, platform:
        OpenCL objects the generator is bound to. They may all be None
        (see ``devicelessTypegen``).
    fbtype: str
        Floating point base type, ``"float"`` or ``"double"``.
    float_dump_mode: str
        Forwarded to ``TypeGen``.
    use_short_circuit_ops, unroll_loops: bool
        Stored as attributes and included in ``__repr__``.

    Raises
    ------
    NotImplementedError
        If *fbtype* is ``"half"``: no half precision vector helpers exist
        in this module.
    ValueError
        If *fbtype* or *float_dump_mode* is unknown, or if the device does
        not support the requested precision.
    """

    @staticmethod
    def devicelessTypegen():
        """
        Sometimes we do not need structs and code generation is
        device independent.
        """
        return OpenClTypeGen(device=None, context=None, platform=None)

    def __init__(
        self,
        device,
        context,
        platform,
        fbtype="float",
        float_dump_mode="dec",
        use_short_circuit_ops=False,
        unroll_loops=False,
    ):
        super().__init__(fbtype, float_dump_mode)

        self.device = device
        self.context = context
        self.platform = platform
        self.use_short_circuit_ops = use_short_circuit_ops
        self.unroll_loops = unroll_loops

        self.vsizes = vsizes
        self.signed_base_types = signed_base_types
        self.unsigned_base_types = unsigned_base_types
        self.integer_base_types = signed_base_types + unsigned_base_types

        self.float_types = float_types
        self.signed_types = signed_types
        self.unsigned_types = unsigned_types
        self.integer_types = integer_types
        self.builtin_types = builtin_types

        self.float_base_type_require = float_base_type_require

        self.basetype = basetype
        self.components = components
        self.vtype = vtype
        self.itype = itype
        self.uitype = uitype
        self.np_dtype = np_dtype
        self.vtype_component_adressing = vtype_component_adressing
        self.vtype_access = vtype_access
        self.mangle_vtype = mangle_vtype
        self.float_to_dec_str = float_to_dec_str
        self.float_to_hex_str = float_to_hex_str

        # pyopencl specifics
        self.intn = intn
        self.uintn = uintn
        self.simplen = simplen
        self.doublen = doublen
        self.typen = typen

        self.make_intn = make_intn
        self.make_uintn = make_uintn
        self.make_simplen = make_simplen
        self.make_doublen = make_doublen
        self.make_typen = make_typen

        if fbtype == "float":
            self.floatn = simplen
            self.make_floatn = make_simplen
            self.dtype = np.float32
        elif fbtype == "double":
            self.floatn = doublen
            self.make_floatn = make_doublen
            self.dtype = np.float64
        elif fbtype == "half":
            # This branch used to reference halfn/make_halfn, which are not
            # defined anywhere in this module, and failed with a NameError.
            raise NotImplementedError(
                "Half precision vector helpers (halfn, make_halfn) are not "
                "available, fbtype 'half' is not supported."
            )
        else:
            raise ValueError(f"Unknown fbtype '{fbtype}'")

        self._ftype_build_options = self.get_precision_opts()

    def ftype_build_options(self):
        """Return the build options computed for ``fbtype`` at construction."""
        return self._ftype_build_options

    def device_has_ftype(self, device):
        """Return True when *device* lists the extension required by ``fbtype``.

        Base types without a required extension are always reported as
        available.
        """
        dev_exts = device.extensions.split(" ")
        req = self.float_base_type_require[self.fbtype]
        # The requirement is a plain extension name: compare the whole name,
        # not its first character.
        return (req is None) or (req in dev_exts)

    def cl_requirements(self):
        """Return a one-element list with the extension required by ``fbtype``.

        The element is None when no extension is required.
        """
        return [self.float_base_type_require[self.fbtype]]

    def opencl_version_greater(self, major, minor):
        """Return True when the device version is strictly greater than major.minor."""
        (cl_major, cl_minor) = self.opencl_version()
        if cl_major < major:
            return False
        if (cl_major == major) and (cl_minor <= minor):
            return False
        return True

    def opencl_version(self):
        """Return ``(major, minor)`` parsed from the device version string.

        Raises
        ------
        RuntimeError
            If the version string does not start with ``OpenCL <d>.<d>``.
        """
        assert self.device is not None
        sversion = self.device.version.strip()
        _regexp = r"OpenCL\s+(\d)\.(\d)"
        match = re.compile(_regexp).match(sversion)
        if not match:
            msg = "Could not extract OpenCL version from device returned version '{}' "
            msg += "and regular expression '{}'."
            msg = msg.format(sversion, _regexp)
            raise RuntimeError(msg)
        major = int(match.group(1))
        minor = int(match.group(2))
        return (major, minor)

    def dtype_from_str(self, stype):
        """Return the type named by *stype*.

        The placeholders ``ftype`` and ``fbtype`` are first replaced by the
        configured floating point base type.
        """
        stype = stype.replace("ftype", self.fbtype).replace("fbtype", self.fbtype)
        btype = basetype(stype)
        N = components(stype)
        return typen(btype, N)

    def dump_expr(self, expr, symbol2vars=None, **printer_settings):
        """
        Print sympy expression expr as OpenCL code.
        Sympy symbols may be replaced using the symbol2vars dictionary.
        This dumper uses OpenClTypeGen.dump for floats and quotients.
        See hysop.backend.device.opencl.opencl_printer.OpenClPrinter
        """
        # Imported here rather than at module level.
        from hysop.backend.device.opencl.opencl_printer import OpenClPrinter

        printer = OpenClPrinter(
            typegen=self, symbol2vars=symbol2vars, **printer_settings
        )
        return printer.doprint(expr)

    def __repr__(self):
        """Used to hash in OpenClKernelAutotuner.autotuner_config_key()"""
        # A deviceless typegen has no platform/device; fall back to None
        # instead of failing on the attribute access.
        platform_name = getattr(self.platform, "name", None)
        device_name = getattr(self.device, "name", None)
        return "{}_{}_{}_{}_{}_{}".format(
            platform_name,
            device_name,
            self.fbtype,
            self.float_dump_mode,
            self.use_short_circuit_ops,
            self.unroll_loops,
        )

    def get_precision_opts(self):
        """
        Check that the device is able to work with the configured precision
        and return the build options for this precision.

        Raises
        ------
        ValueError
            If half or double precision is requested and the matching
            device ``*_fp_config`` value is not positive.
        """
        opts = []
        if self.fbtype == "half":
            if self.device.half_fp_config <= 0:
                raise ValueError("Half precision is not supported on device.")
        elif self.fbtype == "float":
            opts.append("-cl-single-precision-constant")
        elif self.fbtype == "double":
            if self.device.double_fp_config <= 0:
                raise ValueError("Double Precision is not supported on device")
        return opts